#include <qthread_slave.h>
#include "swarch.h"
#ifdef __sw_slave__
#include <cassert>
#include <cstring>
#include <stdio.h>
#include <sys/cdefs.h>
#include "esmd_types.h"
#include "cell.h"
#include "dma_funcs.hpp"
#include "dmapp.hpp"
#include "cal.h"
// Transfer modes for memptr_t: which direction(s) the LDM buffer is
// DMA'd relative to main memory over the wrapper's lifetime.
enum
{
  MP_IN,      // DMA main memory -> LDM on construction only
  MP_OUT,     // DMA LDM -> main memory on destruction only
  MP_INOUT,   // DMA in on construction and back out on destruction
  MP_ZEROOUT  // no DMA-in (buffer meant to start zeroed); DMA out on destruction
};

// RAII LDM-buffered view of `count` elements of a main-memory array:
// DMA-in on construction for MP_IN/MP_INOUT, DMA-out on destruction for
// MP_OUT/MP_INOUT/MP_ZEROOUT. MP_ZEROOUT starts from a zeroed buffer.
template <typename T, int MaxCount, int Mode>
struct memptr_t{
  T data[MaxCount];  // local (LDM) copy of the elements
  T *mem;            // backing main-memory pointer
  int count;         // number of valid elements (<= MaxCount)
  // Wrap `count` elements at `mem`.
  // (Initializer order fixed to match declaration order: mem, then count.)
  __always_inline memptr_t(T *mem, int count) : mem(mem), count(count) {
    assert(count <= MaxCount);
    if (Mode == MP_IN || Mode == MP_INOUT)
      dma_getn(mem, data, count);
    if (Mode == MP_ZEROOUT) {
      // Fixed: previously the zero-fill was commented out (and the draft
      // loop mixed byte offsets with element indexing), so uninitialized
      // LDM garbage was written back on destruction. Zero the valid
      // region byte-wise instead.
      memset(data, 0, count * sizeof(T));
    }
  }
  // Convenience: wrap the full MaxCount elements.
  __always_inline memptr_t(T *mem) : memptr_t(mem, MaxCount) {
  }
  T &operator[](int index){
    return data[index];
  }
  // Implicit decay to the LDM buffer for APIs that take a raw pointer.
  operator T*(){
    return data;
  }
  ~memptr_t(){
    if (Mode == MP_OUT || Mode == MP_INOUT || Mode == MP_ZEROOUT) {
      dma_putn(mem, data, count);
    }
  }
};
// Trait extracting the element count and base type from a built-in
// array type; only the T[N] specialization is usable.
template<typename T>
struct __array_count_helper {
  // primary template intentionally empty: non-array types are unsupported
};
template <typename T, int N>
struct __array_count_helper<T[N]>{
  static constexpr int count = N;
  typedef T base_type;
};

// Alias: element type of a built-in array type.
template <typename T> using __array_base_type = typename __array_count_helper<T>::base_type;

// Maps an array type T[N] to the LDM wrapper type memptr_t<T, N - Offset,
// Mode>, i.e. a view over the array minus its first Offset elements.
template <typename T, int Mode, int Offset>
struct __memptr_for_helper{};
template <typename T, int N, int Mode, int Offset>
struct __memptr_for_helper<T[N], Mode, Offset>{
  typedef memptr_t<T, N - Offset, Mode> memptr_type;
};
template <typename T, int Mode, int Offset=0>
using memptr_for = typename __memptr_for_helper<T, Mode, Offset>::memptr_type;
// Convenience constructors: wrap `count` elements of a fixed-size array
// (starting at element Offset) in an LDM memptr_t of the matching mode.
// T is deduced as the array type, so the wrapper's capacity is N - Offset.

// Read-only view: DMA in on construction, nothing written back.
template<int Offset = 0, typename T>
__always_inline auto array_in(T &ptr, int count) {
  return memptr_for<T, MP_IN>(ptr + Offset, count);
}
// Write-only view: nothing read in, DMA out on destruction.
template<int Offset = 0, typename T>
__always_inline auto array_out(T &ptr, int count) {
  return memptr_for<T, MP_OUT>(ptr + Offset, count);
}
// Read-write view: DMA in on construction and back out on destruction.
template<int Offset = 0, typename T>
__always_inline auto array_inout(T &ptr, int count) {
  return memptr_for<T, MP_INOUT>(ptr + Offset, count);
}

// Fetch a single object of type T from main memory by value via DMA.
template <typename T>
__always_inline T fetch_ptr(T *ptr) {
  T local;
  dma_getn(ptr, &local, 1);
  return local;
}
// Load a variable-sized list object from main memory into LDM: first pull
// a fixed 256-byte head, then — if the object reports a larger effective
// size — pull the remaining bytes in a second DMA.
// NOTE(review): assumes effective_size() is decodable from the first 256
// bytes and returns a byte count — confirm against the list types used.
template <typename T>
__always_inline void list_get(T *mem, T *ldm){
  dma_getn(mem, (char*)ldm, 256);
  if (ldm->effective_size() > 256) {
    dma_getn((char*)mem + 256, (char*)ldm + 256, ldm->effective_size() - 256);
  }
}
// Write a variable-sized list object from LDM back to main memory,
// transferring only its effective (used) byte size.
template <typename T>
__always_inline void list_put(T *mem, T *ldm) {
  dma_putn(mem, (char*)ldm, ldm->effective_size());
}
#include <tuple>
// Direct-mapped read-only LDM cache over per-cell position data
// (cells[].x), in blocks of BLKSIZE entries spread across NBLKS slots.
// When WITH_XUC != 0 a second plane of blocks mirrors cells[].shake_tmp
// for the same offsets.
template<int NBLKS, int BLKSIZE, int WITH_XUC = 0>
struct cell_x_cache {
  //static constexpr int NBLKS = CELL_CAP / BLKSIZE;
  vec<real> data[NBLKS*(1+WITH_XUC)][BLKSIZE];  // x blocks; slots [NBLKS..) hold shake_tmp when WITH_XUC
  int tags[NBLKS];    // global block id cached in each slot; -1 = empty
  celldata_t *cells;  // main-memory cell array being cached
  descriptor<dsc::DMA, dsc::PE, dsc::GET> get;  // DMA get descriptor (project-defined)
  // Flatten (cell, atom) into a global atom offset.
  static __always_inline int offof(int cell_off, int atom_off) {
    return cell_off * CELL_CAP + atom_off;
  }
  // Start with all slots empty and counters at zero.
  __always_inline cell_x_cache(celldata_t *cells) : cells(cells), nacc(0), nmiss(0), get() {
    for (int j = 0; j < NBLKS; j ++)
      tags[j] = -1;
  }
  // Read atom (cell_off, atom_off); a miss DMAs the containing block (and
  // the matching shake_tmp block when WITH_XUC) into the direct-mapped
  // slot. Returns by value: x alone when WITH_XUC == 0, else (x, xuc).
  // NOTE(review): block addressing assumes BLKSIZE is a power of two that
  // divides CELL_CAP, so a block never spans two cells — confirm.
  __always_inline auto operator()(int cell_off, int atom_off){
    int g_off = cell_off * CELL_CAP + atom_off;
    int g_blk = g_off / BLKSIZE;
    int l_blk = g_blk % NBLKS;  // direct-mapped slot
    int off = g_off % BLKSIZE;
    int tag = g_blk;
    if (tags[l_blk] != tag){
      get(cells[cell_off].x + (atom_off & ~(BLKSIZE-1)), data[l_blk], BLKSIZE).syn();
      if (WITH_XUC)
        get(cells[cell_off].shake_tmp + (atom_off & ~(BLKSIZE-1)), data[l_blk + NBLKS], BLKSIZE).syn();
      tags[l_blk] = tag;
      nmiss ++;
    }
    nacc ++;
    // compile-time selection of the return shape via tuple indexing
    return std::get<WITH_XUC>(std::make_tuple(data[l_blk][off], std::make_tuple(data[l_blk][off], data[l_blk + NBLKS * WITH_XUC][off])));
  }
  int nacc, nmiss;  // access / miss counters for cache statistics
};
// Direct-mapped write-combining cache for per-atom forces: partial sums
// accumulate in LDM and are merged into the shared replica array `f`
// under per-block locks (several CPEs may touch the same block).
template<int NBLKS, int BLKSIZE>
struct cell_f_cache {
  //static constexpr int NBLKS = CELL_CAP / BLKSIZE;
  vec<real> data[NBLKS][BLKSIZE];  // per-slot partial sums; zeroed when a slot is (re)assigned
  int tags[NBLKS];                 // global block id per slot; -1 = empty
  vec<real> *f, tmp[BLKSIZE];      // f: main-memory replica array; tmp: LDM scratch for read-modify-write
  cal_lock_t *locks;               // one lock per global block of f
  descriptor<dsc::DMA, dsc::PE, dsc::GET> get;
  descriptor<dsc::DMA, dsc::PE, dsc::PUT> put;
  // Flatten (cell, atom) into a global atom offset.
  static __always_inline int offof(int cell_off, int atom_off) {
    return cell_off * CELL_CAP + atom_off;
  }
  // Start with all slots empty and counters at zero.
  __always_inline cell_f_cache(vec<real> *f, cal_lock_t *locks) : f(f), locks(locks), nacc(0), nmiss(0), get(), put() {
    for (int j = 0; j < NBLKS; j ++)
      tags[j] = -1;
  }
  // Return a reference to the LDM accumulator for atom (cell_off,
  // atom_off). On a slot conflict the evicted block is merged into `f`
  // (lock, fetch, add, write back) before the slot is zeroed.
  __always_inline auto &operator()(int cell_off, int atom_off){
    int g_off = cell_off * CELL_CAP + atom_off;
    int g_blk = g_off / BLKSIZE;
    int l_blk = g_blk % NBLKS;  // direct-mapped slot
    int off = g_off % BLKSIZE;
    int tag = g_blk;
    if (tags[l_blk] != tag){
      int tag_flush = tags[l_blk];
      if (tag_flush != -1){
        // merge the evicted block into f under its lock
        cal_lock(locks + tag_flush);
        //cal_locked_printf("%d\n", tag_flush);
        // NOTE(review): plain dma_getn/dma_putn here vs descriptor
        // get/put with .syn() in flush() — presumably both complete
        // before returning; confirm.
        dma_getn(f + tag_flush * BLKSIZE, tmp, BLKSIZE);//.syn();
        for (int i = 0; i < BLKSIZE; i ++){
          tmp[i] += data[l_blk][i];
        }
        dma_putn(f + tag_flush * BLKSIZE, tmp, BLKSIZE);//.syn();
        cal_unlock(locks + tag_flush);
      }
      //get(cells[cell_off].x + (atom_off & ~(BLKSIZE-1)), data[l_blk], BLKSIZE*sizeof(vec<real>)).syn();
      tags[l_blk] = tag;
      // fresh slot starts from zero: it holds partial sums only
      for (int i = 0; i < BLKSIZE; i ++){
        data[l_blk][i] = 0;
      }
      nmiss ++;
    }
    nacc ++;
    return data[l_blk][off];
  }
  // Zero the replica array `f`, one CELL_CAP chunk per cell, cells
  // distributed round-robin over CPEs; barriers bracket the pass.
  __always_inline void fill(cellgrid_t *grid){
    qthread_syn();
    vec<real> f0[CELL_CAP];
    for (int i = 0; i < CELL_CAP; i ++){
      f0[i] = 0;
    }

    FOREACH_CELL_CPE_RR(grid, i, j, k, cell) {
      int offset = get_offset_xyz<true>(grid, i, j, k);
      put(f + offset * CELL_CAP, f0, CELL_CAP);
    }
    qthread_syn();
  }
  // Merge every resident slot into `f` (same locked read-modify-write as
  // eviction).
  // NOTE(review): tags/data are NOT reset here, so calling flush() twice
  // would add the same partial sums again — confirm it runs once per pass.
  __always_inline void flush(){
    for (int i = 0; i < NBLKS; i ++){
      if (tags[i] != -1){
        int tag_flush = tags[i];
        cal_lock(locks + tag_flush);
        get(f + tag_flush * BLKSIZE, tmp, BLKSIZE).syn();
        for (int j = 0; j < BLKSIZE; j ++){
          tmp[j] += data[i][j];
        }
        put(f + tag_flush * BLKSIZE, tmp, BLKSIZE).syn();
        cal_unlock(locks + tag_flush);
      }
    }
  }
  // Add the replica forces back into each cell's own f array (barrier
  // first so all CPEs have finished writing the replica).
  __always_inline void sum(cellgrid_t *grid){
    qthread_syn();
    FOREACH_CELL_CPE_RR(grid, i, j, k, cell) {
      int offset = get_offset_xyz<true>(grid, i, j, k);
      int natom = cell->natom;
      memptr_t<vec<real>, CELL_CAP, MP_INOUT> fcell(cell->f, natom);
      memptr_t<vec<real>, CELL_CAP, MP_IN> frep(f + offset * CELL_CAP, natom);
      // NOTE(review): this `i` shadows the FOREACH loop's `i`; works, but
      // renaming would be safer.
      for (int i = 0; i < natom; i ++){
        fcell[i] += frep[i];
      }
    }

  }
  // Rigid-body variant of sum(); defined elsewhere.
  __always_inline void sum_rigid(cellgrid_t *grid);
  int nacc, nmiss;  // access / miss counters for cache statistics
};
// Generic element copy: plain assignment for assignable types.
// (Built-in array types get a dedicated specialization, since raw
// arrays cannot be assigned.)
template<typename T>
struct copier {
  __always_inline static void copy(T &dst, const T &src){
    dst = src;
  }
};
// Specialization for built-in arrays T[N]: copies element by element,
// because raw arrays have no assignment operator.
template<typename T, int N>
struct copier<T[N]> {
  __always_inline static void copy(T *dst, const T *src){
    int idx = 0;
    while (idx < N) {
      dst[idx] = src[idx];
      idx ++;
    }
  }
};
// Buffered sequential writer to main memory: appended elements accumulate
// in an LDM buffer and are DMA'd out in chunks of up to Capacity elements.
// The destructor flushes any remainder.
template<typename T, int Capacity>
struct list_writer{
  T buf[Capacity];  // LDM staging buffer
  int cur;          // elements currently buffered
  size_t accu;      // elements already written out to `dest`
  T *dest;          // main-memory destination base
  list_writer(T *dest) : dest(dest), cur(0), accu(0) {
  }

  // Append one element, flushing first if the buffer is full.
  void append(const T &val) {
    if (cur == Capacity) {
      flush();
    }
    // copier<> handles both scalar and built-in-array element types.
    copier<T>::copy(buf[cur ++], val);
  }
  // DMA the buffered elements out and reset the buffer.
  void flush(){
    // Fixed: skip empty flushes — previously a zero-length dma_putn was
    // issued whenever the buffer was empty (e.g. always once more from
    // the destructor right after a full flush).
    if (cur == 0)
      return;
    dma_putn(dest + accu, buf, cur);
    accu += cur;
    cur = 0;
  }

  ~list_writer(){
    flush();
  }
};
#endif
